# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Copyright 2021 Toyota Research Institute.  All rights reserved.
import json
import logging
import os
from collections import OrderedDict, defaultdict

import numpy as np
import torch

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data import detection_utils as d2_utils
from detectron2.structures import Boxes, BoxMode, Instances
from detectron2.utils.visualizer import ColorMode, Visualizer

# File name of the detection results; joined with the inference output directory and
# parsed as JSON by `D2PredictionVisualizer`.
DETECTION_RESULT_FILE = "coco_instances_results.json"
# File name for semantic segmentation predictions. Not referenced in the visible part of this module.
SEMSEG_RESULT_FILE = "sem_seg_predictions.json"

# Maps the string values of the `cfg.VIS.D2.*.COLOR_MODE` config entries to detectron2 `ColorMode` members.
D2_COLORMODE_MAPPING = {
    "image": ColorMode.IMAGE,
    "segm": ColorMode.SEGMENTATION,
    "image_bw": ColorMode.IMAGE_BW,
}

# Module-level logger.
LOG = logging.getLogger(__name__)


def get_tasks_from_cfg(cfg):
    """Return the list of task names enabled in the config.

    Parameters
    ----------
    cfg:
        Config object; only `cfg.MODEL.BOX2D_ON` is read.

    Returns
    -------
    tasks: list of str
        Contains 'bbox2d' when `cfg.MODEL.BOX2D_ON` is truthy.

    Raises
    ------
    AssertionError
        If no task is enabled.
    """
    enabled = ['bbox2d'] if cfg.MODEL.BOX2D_ON else []
    assert enabled, "Empty task."
    return enabled


def create_instances(predictions, image_size, score_threshold, metadata, score_key="score"):
    """Build a detectron2 `Instances` from the prediction dicts of one image.

    Parameters
    ----------
    predictions: list of dict
        Predictions of a single image. Each entry is read for `score_key`, "bbox" and
        "category_id", and optionally for "bbox3d".
    image_size: tuple
        Forwarded to `Instances(...)`; the caller in this module passes (height, width).
    score_threshold: float
        Only predictions with a score strictly greater than this value are kept.
    metadata:
        Dataset metadata. `thing_dataset_id_to_contiguous_id` is used when present;
        otherwise an identity mapping over `thing_classes` is used.
    score_key: str
        Key under which each prediction stores its score.

    Returns
    -------
    Instances
        Carries `scores`, `pred_boxes` (converted from XYWH_ABS to XYXY_ABS) and
        `pred_classes`. `pred_boxes3d` is also set unless a kept prediction lacks "bbox3d".
    """
    all_scores = np.asarray([pred[score_key] for pred in predictions])
    keep = np.nonzero(all_scores > score_threshold)[0]
    kept = [predictions[idx] for idx in keep]

    xywh = np.asarray([pred["bbox"] for pred in kept]).reshape(-1, 4)
    xyxy = BoxMode.convert(xywh, BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)

    if hasattr(metadata, 'thing_dataset_id_to_contiguous_id'):
        to_contiguous = metadata.thing_dataset_id_to_contiguous_id
    else:
        # Without an explicit mapping, assume `category_id` is already a contiguous ID starting at 0.
        to_contiguous = {idx: idx for idx in range(len(metadata.thing_classes))}
    class_ids = np.asarray([to_contiguous[pred["category_id"]] for pred in kept])

    instances = Instances(image_size)
    instances.scores = all_scores[keep]
    instances.pred_boxes = Boxes(xyxy)
    instances.pred_classes = class_ids

    # 3D boxes are optional: a missing "bbox3d" key simply leaves the field unset.
    try:
        instances.pred_boxes3d = torch.as_tensor([pred["bbox3d"] for pred in kept])
    except KeyError:
        pass
    return instances


class D2PredictionVisualizer():
    """
    Renders ground-truth annotations and saved predictions with detectron2's `Visualizer`.

    Adapted from detectron2:
        detectron2.utils.visualizer

    Key difference: predictions are loaded from the inference results on disk generated by COCOEvaluator.
    """
    def __init__(self, cfg, dataset_name, inference_output_dir):
        """
        Parameters
        ----------
        cfg:
            Config object. Reads `INPUT.FORMAT`, `VIS.D2.PREDICTIONS.{SCALE,COLOR_MODE,THRESHOLD}`
            and whatever `get_tasks_from_cfg` needs.
        dataset_name: str
            Name used to look up both `MetadataCatalog` and `DatasetCatalog`.
        inference_output_dir: str
            Directory containing the detection result file (`DETECTION_RESULT_FILE`).
        """
        self._metadata = MetadataCatalog.get(dataset_name)
        self._input_format = cfg.INPUT.FORMAT
        pred_viz_cfg = cfg.VIS.D2.PREDICTIONS
        self._scale = pred_viz_cfg.SCALE
        self._d2_viz_color_mode = D2_COLORMODE_MAPPING[pred_viz_cfg.COLOR_MODE]

        enabled_tasks = get_tasks_from_cfg(cfg)
        records = DatasetCatalog.get(dataset_name)

        # Per-image predicted instances; stays None when 2D detection is not enabled.
        self.pred_instances_by_image = None
        if "bbox2d" not in enabled_tasks:
            return

        results_path = os.path.join(inference_output_dir, DETECTION_RESULT_FILE)
        with open(results_path, 'r') as f:
            raw_predictions = json.load(f)

        # Group the flat list of predictions by the 'image_id' each one carries.
        grouped = defaultdict(list)
        for pred in raw_predictions:
            grouped[pred['image_id']].append(pred)

        score_threshold = pred_viz_cfg.THRESHOLD
        # Iterating over the dataset (not the predictions) also covers images with no predictions:
        # the defaultdict yields an empty list for them.
        for record in records:
            image_id = record['image_id']
            height_width = (record['height'], record['width'])
            grouped[image_id] = create_instances(
                grouped[image_id], height_width, score_threshold, self._metadata
            )

        self.pred_instances_by_image = grouped
        LOG.info(f"Found 2D detection predictions (bbox2d and/or mask2d) for {len(grouped)} images.")

    def _new_visualizer(self, img):
        """Create a fresh `Visualizer` over `img` using this object's metadata, scale and color mode."""
        return Visualizer(img, self._metadata, scale=self._scale, instance_mode=self._d2_viz_color_mode)

    def visualize(self, x):
        """
        Parameters
        ----------
        x: Dict
            One 'dataset_dict'. Reads "file_name", "image_id" and, if present, "annotations".

        Returns
        -------
        viz_images: Dict[np.array]
            Visualizations as RGB images. "viz_gt_instances_d2" is present when `x` has
            annotations; "viz_pred_instance_d2" is present when predictions were loaded.
        """
        # Read the image from disk and bring it to RGB.
        raw_image = d2_utils.read_image(x["file_name"], format=self._input_format)
        rgb_image = d2_utils.convert_image_to_rgb(raw_image, self._input_format)

        outputs = OrderedDict()

        # d2 groundtruth instances viz (2D box, mask, keypoints)
        if 'annotations' in x:
            # Visualizer.draw_dataset_dict() renders various types of annotations;
            # passing only 'annotations' restricts it to the instance annotations.
            annotations_only = {'annotations': x['annotations']}
            gt_canvas = self._new_visualizer(rgb_image).draw_dataset_dict(annotations_only)
            outputs["viz_gt_instances_d2"] = gt_canvas.get_image()

        # d2 instance predictions viz (2D box, mask, keypoints)
        if self.pred_instances_by_image is not None:
            image_predictions = self.pred_instances_by_image[x['image_id']]
            pred_canvas = self._new_visualizer(rgb_image).draw_instance_predictions(image_predictions)
            outputs["viz_pred_instance_d2"] = pred_canvas.get_image()

        return outputs


def draw_gt_instances_d2(gt_instances, img, metadata, scale, instance_mode):
    """Render GT instances through D2's 'Visualizer.draw_instance_predictions()'.

    Parameters
    ----------
    gt_instances: Instances
        Ground-truth instances; its fields and image size are copied into a new `Instances`.
    img:
        Image handed to the `Visualizer`.
    metadata, scale, instance_mode:
        Forwarded to the `Visualizer` constructor.

    Returns
    -------
    The image returned by `get_image()` on the drawn visualization.
    """
    # draw_instance_predictions() works with 'pred_*' fields, so the GT box/class fields are
    # renamed; every other field keeps its name.
    gt_to_pred = {'gt_boxes': 'pred_boxes', 'gt_classes': 'pred_classes'}
    renamed = {gt_to_pred.get(name, name): value for name, value in gt_instances._fields.items()}

    pred_like = Instances(image_size=gt_instances._image_size, **renamed)
    drawer = Visualizer(img, metadata, scale=scale, instance_mode=instance_mode)
    return drawer.draw_instance_predictions(pred_like).get_image()


class D2DataloaderVisualizer:
    """Renders the ground-truth instances of dataloader samples with detectron2's `Visualizer`."""

    def __init__(self, cfg, dataset_name):
        """
        Parameters
        ----------
        cfg:
            Config object. Reads `INPUT.FORMAT` and `VIS.D2.DATALOADER.{SCALE,COLOR_MODE}`.
        dataset_name: str
            Name used to look up `MetadataCatalog`.
        """
        self._metadata = MetadataCatalog.get(dataset_name)
        self._input_format = cfg.INPUT.FORMAT
        loader_viz_cfg = cfg.VIS.D2.DATALOADER
        self._scale = loader_viz_cfg.SCALE
        self._d2_viz_color_mode = D2_COLORMODE_MAPPING[loader_viz_cfg.COLOR_MODE]

    def visualize(self, x):
        """
        Parameters
        ----------
        x: Dict
            One dataloader sample. Reads "image" and "instances".

        Returns
        -------
        viz_images: OrderedDict
            Single entry "viz_gt_instances_d2" holding the rendered ground-truth instances.
        """
        # Assumption: dataloader produces CHW images, so move channels last before converting to RGB.
        channels_last = x['image'].permute(1, 2, 0)
        rgb_image = d2_utils.convert_image_to_rgb(channels_last, self._input_format)

        outputs = OrderedDict()

        # d2 instance viz (2D box, mask, keypoints)
        outputs['viz_gt_instances_d2'] = draw_gt_instances_d2(
            x['instances'], rgb_image, self._metadata, self._scale, self._d2_viz_color_mode
        )
        return outputs
